Skip to content

Add NVFP4 per-token quantization recipe#3045

Open
cael-ling wants to merge 18 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe
Open

Add NVFP4 per-token quantization recipe#3045
cael-ling wants to merge 18 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe

Conversation

@cael-ling

@cael-ling cael-ling commented May 26, 2026

Copy link
Copy Markdown
Contributor

Description

This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.

Changes

  • Per-token cast kernels: vector-amax + encode/swizzle producing NVFP4 tensors whose _amax_rowwise / _amax_columnwise are per-row/per-col vectors.
  • CUTLASS GEMM (nvfp4_cutlass_per_token_gemm) that rescales with the per-row/per-col outer-amax vectors inside the epilogue;
  • Forward + backward coverage (dgrad NN / wgrad NT layouts).
  • NVFP4PerTokenBlockScaling recipe (re-exported from transformer_engine.pytorch.fp8), plus an equivalent NVTE_NVFP4_PER_TOKEN=1 env-var switch on a plain NVFP4BlockScaling so frameworks that only build a default recipe (e.g. Megatron-Core) can opt in with no code change.
  • Opt-in RHT / SR (per_token_rht / per_token_sr) — off by default on the per-token path.
  • Opt-in 2D weight quantization (per_token_weight_2d): transposition-invariant 16×16 cast emitted in per-token layout.
  • Docs: API reference entry, NVTE_NVFP4_* env-var docs, and a "Per-token NVFP4" feature section with Megatron-Core launch instructions.
  • Example: examples/pytorch/nvfp4_per_token_megatron — single-GPU MoE example comparing per-token vs per-tensor vs BF16 with identical model/data/seed.

Ongoing work

The per-token recipe currently targets accuracy evaluation, not optimized production deployment:

  • Requires NVTE_NORM_FWD_USE_CUDNN=1 (unfused norm forward); the fused norm+amax path rejects per-token quantizers.
  • fuse_wgrad_accumulation=True is unsupported → launch with --no-gradient-accumulation-fusion in Mcore.
  • Forward/backward output quantization, communication/bulk overlap, and CUDA graphs are not yet supported/validated.
  • Kernels are functional but not perf-tuned; use for numerical comparison, not perf benchmarking.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.

  * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
    grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
  * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
    kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
    to 2d quant of W).
  * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
    grouped bulk binding and per-token GEMM entry; thin pybind layer.
  * pytorch/custom_recipes/{gemm_nvfp4_per_token,
    quantization_nvfp4_per_token_group}.py: Python wrappers.
  * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
    cast tests + bf16-close GEMM tests.
  * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
    over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
    Graphs columns, ratio against per-tensor RHT+SR baseline.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 6f17fe4 to 928ab1c Compare May 27, 2026 13:09
pre-commit-ci Bot and others added 9 commits May 27, 2026 13:10
…uped)

Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax)
and K2 (encode) kernels for both single-tensor and grouped paths.
with_rht=False is byte-equal to the pre-RHT code path; when true,
applies a 16-pt RHT on the columnwise direction in both K1 and K2
(rowwise stays raw) with outer amax + inner SF self-consistent.

Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32
sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into
block_amax / block_scale (bit-exact).

Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and
byte-equality regressions. Benches gain a --rht flag (2-way default,
3-way under --rht).

Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K:
* single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT)
* grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT)

Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).

The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.

with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).

Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.

Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.

Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.

Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.

New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
  - nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
    (CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
  - nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
    D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
    The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
    matching the value cuBLASLt auto-folds via its amax slot.

Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.

Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
  - fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
    M,N,K = 256..1024.
  - fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
    (sanity check that the EVT tree and the baked constant are both correct).

Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.

bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
  GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
  dispatch). Quant + GEMM inside the timing loop, N = K. Function
  docstring carries an ASCII kernel-pipeline diagram for both paths
  (per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
  quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
  the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
  bwd never re-quantizes). pten side uses RHT-cols + SR grad
  quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
  carries an ASCII kernel-pipeline diagram (per-step launch budget:
  per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
  bf16 per-row*per-col post-scale, "Route 1") next to the existing
  ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
  baseline. Headline ratio lp/cf decides whether to dispatch
  per-token through cuBLASLt + post_scale or fused EVT; current
  data shows ct_fused wins or ties at every shape we care about.

test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
  truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
  measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
  uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
  per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
  (catches breakage cheaper than the broader Layer 3 test).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad

dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use DIVUP here to handle the remainder case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.

// After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot.
//
// kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread
// FHT with random_sign_mask_t). Row direction never sees RHT.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Row direction never sees RHT -> Row direction never uses RHT

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

}
}
#else
NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell).");

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:

  1. The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
  2. The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.

cael-ling and others added 8 commits June 2, 2026 01:28
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad
and wgrad, and let per-token opt into RHT via
NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled
(kernels unimplemented at this commit). Extends the per-token CUTLASS
GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus
dgrad/wgrad numerical tests and a fwd+bwd module smoke test.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Thread a Philox rng_state and a kWithSr template flag through the
per-token encode kernel (rowwise + colwise) and the
nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor
SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build
the rng_state when stochastic rounding is requested. Add a per_token_sr
recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer
factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN
when averaged -- and RN-determinism / SR-nondeterminism) folded into
test_nvfp4_per_token.py.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Wire with_sr + rng_state through the grouped per-token C-API and cast
dispatch, implement the SR FP4 cast in the grouped kernel, and drop the
"per-token does not support SR" guard. Also fix two comment typos
(sees -> uses) in quantize_nvfp4_per_token.cu per review.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d),
default off so the per-token path stays byte-equal. When enabled, only the
forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile + scalar
outer amax) re-dressed in per-token tensor layout: the scalar outer amax is
broadcast across the per-row/col alpha vectors and the inner SF is the same
16-row-replicated 2D tile, so the existing per-token CUTLASS GEMM consumes it
unchanged with no kernel modification. Activation/gradient casts stay
per-token 1D.

Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a
runnable single-GPU example so the recipe can be exercised end to end.

- docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference.
- docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation
  (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the
  per-tensor disable flags.
- docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining the
  per-row/per-col outer-amax cast, its differences from the per-tensor default
  (RHT/SR off by default, forced-off knobs, unfused-norm requirement), and how
  to launch it with Megatron-Core.
- recipe/__init__.py: document the per_token_rht/per_token_sr/per_token_weight_2d
  constructor kwargs and drop the stale "stochastic rounding unsupported" note.
- pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling.
- examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run +
  sbatch + job-chain scripts and README) comparing per-token vs per-tensor vs
  BF16 with identical model/data/seed.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling marked this pull request as ready for review June 11, 2026 07:58
@greptile-apps

greptile-apps Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds an NVFP4 per-token quantization recipe to Transformer Engine, replacing the per-tensor outer amax with per-row (length M) and per-col (length K) amax vectors. The implementation spans new CUDA cast kernels (K1 vector-amax + K2 FP4 encode), a fused CUTLASS EVT GEMM that folds the per-row/per-col outer-scale vectors directly into the bf16 epilogue, recipe classes (NVFP4PerTokenBlockScaling, env-var activation via NVTE_NVFP4_PER_TOKEN), and plumbing through the quantizer, tensor storage, and GEMM dispatch layers.

  • New kernels and GEMM: quantize_nvfp4_per_token.cu + quantize_nvfp4_per_token_group.cu implement the composite K1+K2 per-token cast; nvfp4_cutlass_gemm.cu adds the fused EVT per-token GEMM that replaces the legacy cuBLAS LT + post-scale pass.
  • Recipe and dispatch: NVFP4PerTokenBlockScaling (and env-var activation of NVFP4BlockScaling) forces mutual-exclusion with 2D quantization, row-scaled activation, and 4over6; _nvfp4_per_token_gemm in gemm.py dispatches TN/NN/NT layouts to the fused kernel, including an N-D shape restore for batch-flattened activations.
  • Feature flags and limitations: per-token RHT, SR, and 2D weight quantization are opt-in; fuse_wgrad_accumulation, output quantization, comm-overlap, and CUDA graphs are explicitly unsupported and raise at runtime.

Confidence Score: 3/5

The PR adds a large new quantization path (13k+ lines) described as an accuracy-evaluation MVP. The new code is well-guarded for the intended end-to-end flow, but the weight-2D path in quantizer.cpp has a null-pointer in the amax restore that, while currently unreachable, lives adjacent to defensively-written null checks and should be fixed before the code grows more callers.

The weight-2D amax restore in quantizer.cpp calls out.set_amax(rowwise_amax_ptr, ...) immediately after picking amax_ptr as the non-null fallback. If rowwise_amax_ptr is null, nvte_quantize_v2 would crash reading the amax. The path is shielded today by construction invariants, but the surrounding code already includes defensive nullptr checks that the critical line does not match. Combined with several explicitly documented MVP limitations, the change needs targeted fixes before it can be considered fully production-ready.

transformer_engine/pytorch/csrc/quantizer.cpp (per-token weight-2D amax restore), transformer_engine/pytorch/cpp_extensions/gemm.py (alpha forwarding), transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py (0-token split handling)

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/quantizer.cpp Major additions: per-token NVFP4 dispatch in create_tensor/convert_and_update_tensor/quantize_impl; a latent null-pointer in the weight-2D amax restore path (guarded by construction invariant but not defensive); logic is otherwise well-structured with mutex checks.
transformer_engine/pytorch/cpp_extensions/gemm.py New _nvfp4_per_token_gemm dispatch and grouped-GEMM loop; layout dispatch table (TN/NN/NT) and N-D shape restoration logic are correct; alpha scalar is silently ignored for the per-token path; overall routing is sound.
transformer_engine/common/recipe/init.py Adds NVFP4PerTokenBlockScaling subclass, per-token env-var activation via NVTE_NVFP4_PER_TOKEN, and mutex-forcing _force_per_token_settings; classmethod polymorphic dispatch and post_init flow are correct.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py Reference and production per-token quantize wrappers; arithmetic chain mirrors per-tensor NVFP4 reference; shape allocation and CUDA kernel arguments look correct.
transformer_engine/pytorch/custom_recipes/gemm_nvfp4_per_token.py Production per-token GEMM wrapper; dequant reference math and validation checks are consistent; beta!=0 guard correctly prevents unsupported accumulate path.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py Grouped per-token quantize wrapper; correctly delegates to C++ bulk binding; validation rejects M_i=0 which breaks unbalanced MoE scenarios where some experts receive zero tokens.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds per_token flag to NVFP4Quantizer/NVFP4Tensor; propagates through FSDP2 metadata, _ViewFunc, _ReshapeFunc, and reduce_ex; mutex checks in Python constructor mirror C++; FSDP2 tuple ordering change is correctly matched in unpack.
transformer_engine/pytorch/csrc/extensions/nvfp4_cutlass_gemm.cpp New nvfp4_cutlass_per_token_gemm binding; input validation (dtype, contiguity, shape, amax vector sizes) is thorough; SF swizzle logic mirrors the scalar-alpha entry point.
transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp Per-token quantize C++ bindings (composite K1+K2, K1-only, K2-only, grouped); shape validation and TensorWrapper assembly are correct; mode dispatch (0/1/2) is clearly documented.

Sequence Diagram

sequenceDiagram
    participant Recipe as NVFP4PerTokenBlockScaling
    participant Q as NVFP4Quantizer
    participant K1 as nvte_nvfp4_per_token_quantize
    participant Tensor as NVFP4Tensor
    participant GEMM as _nvfp4_per_token_gemm
    participant CGEMM as nvfp4_cutlass_per_token_gemm
    Recipe->>Q: "per_token=True rowwise=True columnwise=True"
    Q->>Q: amax_rowwise(M,) amax_columnwise(K,)
    Q->>K1: quantize input
    K1-->>Tensor: data(M,K/2) sf(M,K/16) amax(M,)
    K1-->>Tensor: col_data(K,M/2) col_sf(K,M/16) col_amax(K,)
    GEMM->>CGEMM: TN/NN/NT layout dispatch
    CGEMM-->>GEMM: "D[i,j]=bf16(amax_a[i]*amax_b[j]*(A@B^T)[i,j])"
Loading

Reviews (1): Last reviewed commit: "Add docs and Megatron-Core example for t..." | Re-trigger Greptile

Comment on lines +2399 to +2413
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer for per-token weight-2D.");

// 1. Single scalar tensor amax -> amax[0] (mirror the per-tensor no-RHT path:
// treat the buffer as length 1 for the reduction, then fan out to both
// rowwise/columnwise amax[0]).
out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), w2d_config, stream); });
out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Latent null-pointer dereference in per-token weight-2D amax restore

After nvte_compute_amax_with_config writes the global amax into amax_ptr[0], the code restores the output tensor's amax field with out.set_amax(rowwise_amax_ptr, ...). If rowwise_amax_ptr is nullptr (i.e., the quantizer was constructed with rowwise=False), this sets the output's amax descriptor to a null pointer. The immediately following nvte_quantize_v2 then tries to read amax[0] to derive S_enc and will crash.

Currently this path is unreachable because per_token_weight_2d is only set for weight quantizers, and all weight quantizers in the recipe are constructed with rowwise=True, columnwise=True. However, the guard in step 3 (if (rowwise_amax_ptr != nullptr && w2d_rows > 1)) shows the author anticipated both pointers could be null, while the critical out.set_amax call on this line does not. Using amax_ptr (the non-null pointer already validated by the NVTE_CHECK above) would be safe in all configurations: out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1}).

Comment on lines +507 to +530
# Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row
# (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot,
# so this MUST short-circuit before the row-scaled-or-generic fork.
if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B):
if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)):
raise NotImplementedError(
"NVFP4 per-token GEMM requires both A and B to be per-token tensors. "
"Mixing per-token + prod NVFP4 in one GEMM is not supported."
)
out = _nvfp4_per_token_gemm(
A,
B,
transa=transa,
transb=transb,
out=out,
out_dtype=out_dtype,
bias=bias,
grad=grad,
accumulate=accumulate,
gelu=gelu,
quantization_params=quantization_params,
ub=ub,
extra_output=extra_output,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 alpha scalar silently ignored for per-token GEMM

general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.

Comment on lines +47 to +51
for i, M_i in enumerate(split_sections):
if M_i <= 0:
raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}")
if M_i % _PER_TOKEN_TILE != 0:
raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Public grouped-quantize API unconditionally rejects 0-token splits

split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants